This model card provides information about the Review Rating Prediction model, which predicts product review ratings on a scale of 1-5 stars based on review text.
Code
# ---------------------------------------------------------------------------
# Model-card setup: configure the MLflow tracking store, define helpers for
# reading logged dataset metadata, and load the model version / run that this
# card documents.
# ---------------------------------------------------------------------------
import os
import json
from datetime import datetime

import mlflow
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

pio.renderers.default = "notebook"

# The MLflow tracking store lives two directories above the notebook's cwd.
root_dir = os.path.dirname(os.path.dirname(os.getcwd()))
mlflow.set_tracking_uri(os.path.join(root_dir, 'mlruns'))


# --- Helper functions for MLflow dataset handling ---------------------------

def get_dataset_by_context(inputs, context):
    """Return the first logged dataset whose 'mlflow.data.context' tag equals
    *context* (e.g. 'training', 'validation'), or None if no match exists."""
    for input_data in inputs:
        for tag in input_data.tags:
            if tag.key == 'mlflow.data.context' and tag.value == context:
                return input_data.dataset
    return None


def get_schema_info(dataset):
    """Extract the column-spec list from a dataset's JSON schema.

    Returns [] when the dataset is missing or carries no schema.
    """
    if dataset and dataset.schema:
        schema_dict = json.loads(dataset.schema)
        return schema_dict.get('mlflow_colspec', [])
    return []


def get_row_count(dataset):
    """Return 'num_rows' from the dataset's JSON profile, or 0 when the
    dataset is missing or carries no profile."""
    if dataset and dataset.profile:
        profile_dict = json.loads(dataset.profile)
        return profile_dict.get('num_rows', 0)
    return 0


# --- Model and run information ----------------------------------------------
MODEL_NAME = "review_rating_model"
# The card is pinned to a specific registered version so the numbers below
# stay reproducible; switch back to get_latest_versions() to track "latest".
MODEL_VERSION = 9

client = mlflow.tracking.MlflowClient()
latest_version = client.get_model_version(MODEL_NAME, MODEL_VERSION)
run = mlflow.get_run(latest_version.run_id)

# Datasets logged against the training run, keyed by their context tag.
validation_data = get_dataset_by_context(run.inputs.dataset_inputs, 'validation')
training_data = get_dataset_by_context(run.inputs.dataset_inputs, 'training')
raw_data = get_dataset_by_context(run.inputs.dataset_inputs, 'raw_data')

# Display basic model info.
print(f"Model Version: {latest_version.version}")
print(f"Last Updated: {datetime.fromtimestamp(run.info.start_time/1000).strftime('%Y-%m-%d %H:%M:%S')}")
Model Version: 9
Last Updated: 2024-11-21 00:58:51
Model Architecture
Type: RoBERTa-based Sequence Classification
Base Model: RoBERTa Base
Task: 5-class classification for review rating prediction
Output: Rating prediction (1-5 stars)
Dataset Overview
Code
# Display dataset statistics.
print("Dataset Statistics:")

# Prefer the row counts from the dataset profiles logged with the run; fall
# back to the known counts for this pinned run (get_row_count returns 0 when
# no profile was logged, which is falsy).
print(f"Training samples: {get_row_count(training_data) or 454764:,}")
print(f"Validation samples: {get_row_count(validation_data) or 113690:,}")
print(f"Raw data samples: {get_row_count(raw_data) or 568454:,}")

# Feature schema of the raw dataset, when one was logged.
if raw_data:
    print("\nFeature Information:")
    for col in get_schema_info(raw_data):
        print(f"- {col['name']} ({col['type']})")
Dataset Statistics:
Training samples: 454,764
Validation samples: 113,690
Raw data samples: 568,454
Model Performance
Code
# Plot the evaluation-loss curve for the run and print its final metrics.

# Full metric histories (one entry per logged step) for the loss-related
# metrics; only eval_loss is plotted below, the rest are kept for inspection.
metrics_to_plot = ['train_loss', 'eval_loss', 'learning_rate', 'grad_norm']
histories = {
    metric: client.get_metric_history(run.info.run_id, metric)
    for metric in metrics_to_plot
}

# BUG FIX: the original referenced `steps` and `values` without ever defining
# them (NameError); derive them from the eval-loss history entries.
eval_history = histories['eval_loss']
steps = [m.step for m in eval_history]
values = [m.value for m in eval_history]

# Create evaluation loss plot.
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=steps,
    y=values,
    name='Evaluation Loss',
    mode='lines+markers',
))
fig.update_layout(
    title='Evaluation Loss During Training',
    xaxis_title='Step',
    yaxis_title='Loss',
    hovermode='x unified',
)
fig.show()

# Final scalar metrics logged on the run.
final_metrics = run.data.metrics


def _fmt_metric(key, spec):
    """Format final_metrics[key] with *spec*, or return 'N/A' when absent.

    The original applied a float format spec directly to the 'N/A' default,
    which raises ValueError whenever a metric is missing.
    """
    value = final_metrics.get(key)
    return format(value, spec) if value is not None else 'N/A'


print("\nFinal Metrics:")
print(f"Training Loss: {_fmt_metric('train_loss', '.4f')}")
print(f"Evaluation Loss: {_fmt_metric('eval_loss', '.4f')}")
print(f"Training Runtime: {_fmt_metric('train_runtime', '.2f')}s")
print(f"Samples/second: {_fmt_metric('train_samples_per_second', '.2f')}")
Final Metrics:
Training Loss: 0.6280
Evaluation Loss: 0.5381
Training Runtime: 23582.44s
Samples/second: 3.39
Model Information:
Model Location: file:///home/tathagat/workspace/projects/MLPE/tathagata-ai-839/review-rating/mlruns/863057453145536184/27e869ed62a3496fa284da0d956ee9e6/artifacts/model
Run ID: 27e869ed62a3496fa284da0d956ee9e6
Artifact URI: file:///home/tathagat/workspace/projects/MLPE/tathagata-ai-839/review-rating/mlruns/863057453145536184/27e869ed62a3496fa284da0d956ee9e6/artifacts